{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-29T09:59:41.848917Z",
     "start_time": "2020-03-29T09:59:41.843223Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
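  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Overall structure\n",
    "- In addition to the fully connected layer (`Affine`) used so far, a CNN introduces convolution layers (Convolution) and pooling layers (Pooling).\n",
    "- The `Affine - ReLU` block becomes `Conv - ReLU - (Pooling)`.\n",
    "\n",
    "\n",
    "- A fully connected layer ignores the shape of the data, so it cannot use shape-related information.\n",
    "    - For example, 3-dimensional image data (height, width, channels) is flattened into a single column before it enters a fully connected layer.\n",
    "    - The shape of an image carries important spatial information: neighboring pixels tend to have similar values, distant pixels are largely unrelated, the `RGB` channels are correlated with one another, and essential patterns may be hidden in the 3-dimensional shape.\n",
    "- A `CNN` preserves the shape of the data. The input of a convolution layer is called the \"input feature map\" and its output the \"output feature map\".\n",
    "- Convolution layers detect patterns in the input (such as horizontal or vertical lines in an image). As convolution layers are stacked, the network gradually recognizes higher-order patterns (such as specific shapes in an image)."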
   ]
  },
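  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Convolution: filter (kernel), padding, and stride.\n",
    "    - The filter (feature detector) slides over the input data; at each position, the filter elements are multiplied by the corresponding input elements and the products are summed (multiply-accumulate).\n",
    "\n",
    "    - For input size $(H,W)$, filter size $(FH,FW)$, padding $P$, and stride $S$, the output size is $OH=\\frac{H+2P-FH}{S}+1,OW=\\frac{W+2P-FW}{S}+1$.\n",
    "\n",
    "    - For multi-channel input, a single filter must have the same number of channels as the input. Each filter channel is multiplied and accumulated with the matching input channel, and the per-channel results are then summed. With input size $(C,H,W)$ and filter size $(C,FH,FW)$, the output is $(1,OH,OW)$, that is, a feature map with a single channel.\n",
    "\n",
    "    - With multiple filters, the output is a multi-channel feature map. With input size $(C,H,W)$ and filter size $(FN,C,FH,FW)$, filtering yields $(FN,OH,OW)$; adding a bias of shape $(FN,1,1)$ gives the final output $(FN,OH,OW)$.\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "- Pooling: an operation that shrinks the spatial size along the height and width; the number of channels does not change.\n",
    "    - A pooling layer downsamples its input, which reduces the amount of computation, the number of parameters, and memory use, and lowers the risk of overfitting.\n",
    "    - A pooling layer only takes the maximum or the average of each target region, so it has no parameters to learn.\n",
    "\n",
    "    - In general, the pooling window size and the stride are set to the same value.\n",
    "    - When the input shifts slightly, pooling tends to return the same result, so pooling is robust to small positional changes."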
   ]
  },
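  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check of the output-size formula, here is a worked example. The numbers are chosen only for illustration: a $5\\times 5$ input, a $3\\times 3$ filter, no padding, and a stride of 2.\n",
    "\n",
    "$$\n",
    "OH=\\frac{H+2P-FH}{S}+1=\\frac{5+2\\cdot 0-3}{2}+1=2,\\qquad OW=\\frac{W+2P-FW}{S}+1=\\frac{5+2\\cdot 0-3}{2}+1=2\n",
    "$$\n",
    "\n",
    "So the filter fits at $2\\times 2=4$ positions. When $H+2P-FH$ is not divisible by $S$, the formula does not give an integer; implementations commonly either reject such settings or round down."
   ]
  },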
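  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementing the convolution layer\n",
    "\n",
    "## Preprocessing the input data\n",
    "- Convolution could be implemented with nested `for` loops over `NumPy` arrays, but that is slow. To use `NumPy` matrix multiplication instead, the input data must first be preprocessed.\n",
    "![](../images/im2col.png)\n",
    "- `(A)` is the 3-dimensional input data $(C,H,W)$.\n",
    "- `(B)` extracts every region the filter passes over, giving $(C,FH,FW,OH,OW)$.\n",
    "- `(C)` flattens each $(C,FH,FW)$ block into a row of length $C\\times FH\\times FW$, turning the whole input into a 2-dimensional array $(OH\\times OW,C\\times FH\\times FW)$.\n",
    "\n",
    "- The filters $(FN,C,FH,FW)$ are expanded into columns, giving $(C\\times FH\\times FW, FN)$.\n"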
   ]
  },
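  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shapes in steps `(A)` to `(C)` can be traced through the matrix product as follows. This is a sketch that assumes a batch of $N$ inputs (the batch dimension is not shown in the figure) and writes $\\mathrm{col}$ and $W_{\\mathrm{col}}$ for the expanded input and the expanded filters; both names are introduced here only for illustration.\n",
    "\n",
    "$$\n",
    "\\underbrace{\\mathrm{col}}_{(N\\cdot OH\\cdot OW,\\ C\\cdot FH\\cdot FW)}\\ \\underbrace{W_{\\mathrm{col}}}_{(C\\cdot FH\\cdot FW,\\ FN)}+b\\ \\longrightarrow\\ (N\\cdot OH\\cdot OW,\\ FN)\\ \\xrightarrow{\\text{reshape, transpose}}\\ (N,FN,OH,OW)\n",
    "$$\n",
    "\n",
    "For example, with $N=1$, $C=1$, a $5\\times 5$ input, a $3\\times 3$ filter, stride 2, and no padding, $OH=OW=2$, so $\\mathrm{col}$ has shape $(1\\cdot 2\\cdot 2,\\ 1\\cdot 3\\cdot 3)=(4,9)$."
   ]
  },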
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-29T13:44:30.422029Z",
     "start_time": "2020-03-29T13:44:30.411977Z"
    }
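   },
   "outputs": [],
   "source": [
    "def im2col(input_data, filter_h, filter_w, stride=1, pad=0):\n",
    "    \"\"\"Unfold every filter window of a batch of images into a row of a 2-D array.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    input_data : 4-D input array of shape (batch size, channels, height, width)\n",
    "    filter_h : filter height\n",
    "    filter_w : filter width\n",
    "    stride : stride\n",
    "    pad : padding\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    col : 2-D array of shape\n",
    "        (batch size * output height * output width,\n",
    "         channels * filter height * filter width)\n",
    "    \"\"\"\n",
    "    N, C, H, W = input_data.shape\n",
    "    out_h = (H + 2 * pad - filter_h) // stride + 1\n",
    "    out_w = (W + 2 * pad - filter_w) // stride + 1\n",
    "    # Zero-pad only the two spatial axes.\n",
    "    img = np.pad(input_data, [(0, 0), (0, 0), (pad, pad), (pad, pad)],\n",
    "                 'constant')\n",
    "    col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))\n",
    "\n",
    "    # For each offset (y, x) inside the filter, copy the strided slice that\n",
    "    # this filter element sees across all output positions.\n",
    "    for y in range(filter_h):\n",
    "        y_max = y + stride * out_h\n",
    "        for x in range(filter_w):\n",
    "            x_max = x + stride * out_w\n",
    "            col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]\n",
    "    # (N, C, FH, FW, OH, OW) -> (N, OH, OW, C, FH, FW) -> 2-D\n",
    "    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N * out_h * out_w, -1)\n",
    "    return col\n",
    "\n",
    "\n",
    "def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):\n",
    "    \"\"\"Fold a 2-D column array back into an array of shape input_shape.\n",
    "\n",
    "    Values that come from overlapping windows are accumulated, which is what\n",
    "    the backward passes of the convolution and pooling layers need.\n",
    "    \"\"\"\n",
    "    N, C, H, W = input_shape\n",
    "    out_h = (H + 2 * pad - filter_h) // stride + 1\n",
    "    out_w = (W + 2 * pad - filter_w) // stride + 1\n",
    "    col = col.reshape(N, out_h, out_w, C, filter_h,\n",
    "                      filter_w).transpose(0, 3, 4, 5, 1, 2)\n",
    "    # The extra (stride - 1) rows/columns keep the strided slices in bounds.\n",
    "    img = np.zeros((N, C, H + 2 * pad + stride - 1, W + 2 * pad + stride - 1))\n",
    "    for y in range(filter_h):\n",
    "        y_max = y + stride * out_h\n",
    "        for x in range(filter_w):\n",
    "            x_max = x + stride * out_w\n",
    "            img[:, :, y:y_max:stride, x:x_max:stride] += col[:, :, y, x, :, :]\n",
    "    # Strip the padding before returning.\n",
    "    return img[:, :, pad:H + pad, pad:W + pad]"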
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-29T13:48:25.938558Z",
     "start_time": "2020-03-29T13:48:25.931653Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[[0.38201987 0.97114608 0.35504192 0.75298388 0.50798056]\n",
      "   [0.31585327 0.62507244 0.92727894 0.65885742 0.93872894]\n",
      "   [0.6375136  0.32163867 0.78459359 0.88176927 0.49661809]\n",
      "   [0.90116761 0.82551255 0.10971107 0.98956697 0.67861124]\n",
      "   [0.12791343 0.94203226 0.38407891 0.04116853 0.09322415]]]]\n",
      "[[0.38201987 0.97114608 0.35504192 0.31585327 0.62507244 0.92727894\n",
      "  0.6375136  0.32163867 0.78459359]\n",
      " [0.35504192 0.75298388 0.50798056 0.92727894 0.65885742 0.93872894\n",
      "  0.78459359 0.88176927 0.49661809]\n",
      " [0.6375136  0.32163867 0.78459359 0.90116761 0.82551255 0.10971107\n",
      "  0.12791343 0.94203226 0.38407891]\n",
      " [0.78459359 0.88176927 0.49661809 0.10971107 0.98956697 0.67861124\n",
      "  0.38407891 0.04116853 0.09322415]]\n"
     ]
    }
   ],
   "source": [
    "x = np.random.rand(1, 1, 5, 5)\n",
    "print(x)\n",
    "x = im2col(x, 3, 3, 2, 0)\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convolution layer implementation with backpropagation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-29T13:50:26.978468Z",
     "start_time": "2020-03-29T13:50:26.968847Z"
    }
   },
   "outputs": [],
   "source": [
    "class Convolution:\n",
    "    \"\"\"Convolution layer implemented with im2col and matrix multiplication.\"\"\"\n",
    "    def __init__(self, W, b, stride=1, pad=0):\n",
    "        self.W = W  # filters, shape (FN, C, FH, FW)\n",
    "        self.b = b  # bias, one value per filter\n",
    "        self.stride = stride\n",
    "        self.pad = pad\n",
    "\n",
    "        # Intermediate values cached for the backward pass\n",
    "        self.x = None\n",
    "        self.col = None\n",
    "        self.col_W = None\n",
    "\n",
    "        # Gradients of the weights and the bias\n",
    "        self.dW = None\n",
    "        self.db = None\n",
    "\n",
    "    def forward(self, x):\n",
    "        FN, C, FH, FW = self.W.shape\n",
    "        N, C, H, W = x.shape\n",
    "        out_h = 1 + (H + 2 * self.pad - FH) // self.stride\n",
    "        out_w = 1 + (W + 2 * self.pad - FW) // self.stride\n",
    "\n",
    "        # Unfold the input and the filters into 2-D matrices.\n",
    "        col = im2col(x, FH, FW, self.stride, self.pad)\n",
    "        col_W = self.W.reshape(FN, -1).T\n",
    "\n",
    "        out = np.dot(col, col_W) + self.b\n",
    "        # (N*OH*OW, FN) -> (N, OH, OW, FN) -> (N, FN, OH, OW)\n",
    "        out = out.reshape(N, out_h, out_w, -1).transpose(0, 3, 1, 2)\n",
    "\n",
    "        self.x = x\n",
    "        self.col = col\n",
    "        self.col_W = col_W\n",
    "        return out\n",
    "\n",
    "    def backward(self, dout):\n",
    "        FN, C, FH, FW = self.W.shape\n",
    "        # (N, FN, OH, OW) -> (N*OH*OW, FN), matching the forward matrix form\n",
    "        dout = dout.transpose(0, 2, 3, 1).reshape(-1, FN)\n",
    "        self.db = np.sum(dout, axis=0)\n",
    "        self.dW = np.dot(self.col.T, dout)\n",
    "        self.dW = self.dW.transpose(1, 0).reshape(FN, C, FH, FW)\n",
    "\n",
    "        # Gradient w.r.t. the unfolded input, folded back to the input shape\n",
    "        dcol = np.dot(dout, self.col_W.T)\n",
    "        dx = col2im(dcol, self.x.shape, FH, FW, self.stride, self.pad)\n",
    "        return dx"
   ]
  },
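  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The gradients computed in the backward pass of a convolution layer follow from the matrix form of the forward pass. The derivation below is a sketch; the symbols $X_{\\mathrm{col}}$ (the `im2col` output), $W_{\\mathrm{col}}$ (the flattened filters), $Y$, and the loss $L$ are notation introduced here for illustration.\n",
    "\n",
    "$$\n",
    "Y=X_{\\mathrm{col}}W_{\\mathrm{col}}+b\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\frac{\\partial L}{\\partial W_{\\mathrm{col}}}=X_{\\mathrm{col}}^{\\top}\\frac{\\partial L}{\\partial Y},\\qquad \\frac{\\partial L}{\\partial b}=\\sum_{i}\\left(\\frac{\\partial L}{\\partial Y}\\right)_{i,:},\\qquad \\frac{\\partial L}{\\partial X_{\\mathrm{col}}}=\\frac{\\partial L}{\\partial Y}W_{\\mathrm{col}}^{\\top}\n",
    "$$\n",
    "\n",
    "Here $\\partial L/\\partial Y$ has shape $(N\\cdot OH\\cdot OW,\\ FN)$. The last gradient is still in column form, so it is folded back to the input shape $(N,C,H,W)$ with `col2im`, which sums the contributions of overlapping windows."
   ]
  },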
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-29T13:54:29.253586Z",
     "start_time": "2020-03-29T13:54:29.245923Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[[0.47333855 0.30416613 0.34319674 0.73066228 0.43317542]\n",
      "   [0.69969408 0.85851103 0.0565631  0.1898052  0.23868948]\n",
      "   [0.66851131 0.57649214 0.94791586 0.03002559 0.00995395]\n",
      "   [0.63725434 0.07555025 0.76421214 0.64179982 0.53745217]\n",
      "   [0.80886616 0.68542613 0.48702812 0.74047263 0.17961829]]]]\n",
      "[[[[2.34473842 1.41175451]\n",
      "   [2.61711026 2.79379658]]]]\n"
     ]
    }
   ],
   "source": [
    "x = np.random.rand(1, 1, 5, 5)\n",
    "print(x)\n",
    "W = np.random.rand(1, 1, 3, 3)\n",
    "b = np.random.rand(1, 1, 1)\n",
    "conv = Convolution(W, b, 2, 0)\n",
    "x = conv.forward(x)\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Implementing the pooling layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-29T10:50:29.391422Z",
     "start_time": "2020-03-29T10:50:29.382710Z"
    }
   },
   "outputs": [],
   "source": [
    "class Pooling:\n",
    "    \"\"\"Max pooling layer implemented with im2col.\"\"\"\n",
    "    def __init__(self, pool_h, pool_w, stride=1, pad=0):\n",
    "        self.pool_h = pool_h\n",
    "        self.pool_w = pool_w\n",
    "        self.stride = stride\n",
    "        self.pad = pad\n",
    "\n",
    "        # Values cached for the backward pass\n",
    "        self.x = None\n",
    "        self.arg_max = None\n",
    "\n",
    "    def forward(self, x):\n",
    "        N, C, H, W = x.shape\n",
    "        # Include the padding so the sizes agree with what im2col produces.\n",
    "        out_h = (H + 2 * self.pad - self.pool_h) // self.stride + 1\n",
    "        out_w = (W + 2 * self.pad - self.pool_w) // self.stride + 1\n",
    "        col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad)\n",
    "        # One row per pooling window and channel\n",
    "        col = col.reshape(-1, self.pool_h * self.pool_w)\n",
    "\n",
    "        arg_max = np.argmax(col, axis=1)\n",
    "        out = np.max(col, axis=1)\n",
    "        out = out.reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2)\n",
    "        self.x = x\n",
    "        self.arg_max = arg_max\n",
    "        return out\n",
    "\n",
    "    def backward(self, dout):\n",
    "        dout = dout.transpose(0, 2, 3, 1)\n",
    "        pool_size = self.pool_h * self.pool_w\n",
    "        # Route each upstream gradient to the position that held the maximum.\n",
    "        dmax = np.zeros((dout.size, pool_size))\n",
    "        dmax[np.arange(self.arg_max.size),\n",
    "             self.arg_max.flatten()] = dout.flatten()\n",
    "        dmax = dmax.reshape(dout.shape + (pool_size, ))\n",
    "        dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)\n",
    "        dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w, self.stride,\n",
    "                    self.pad)\n",
    "        return dx"
   ]
  },
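  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate how max pooling passes gradients even though it has no learned parameters, consider a single $2\\times 2$ window. The numbers and the upstream gradient $g$ are made up for illustration.\n",
    "\n",
    "$$\n",
    "\\begin{pmatrix}1&3\\\\2&0\\end{pmatrix}\\ \\xrightarrow{\\max}\\ 3,\\qquad \\frac{\\partial L}{\\partial x}=\\begin{pmatrix}0&g\\\\0&0\\end{pmatrix}\n",
    "$$\n",
    "\n",
    "Only the position that held the maximum receives the upstream gradient; this is the role of the stored `arg_max` index in `Pooling.backward`."
   ]
  },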
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-29T13:56:32.745982Z",
     "start_time": "2020-03-29T13:56:32.739895Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[[0.35381993 0.60204659 0.9236718  0.85403515 0.51023189]\n",
      "   [0.06734814 0.68786701 0.79496244 0.61514594 0.76761991]\n",
      "   [0.85729933 0.87327673 0.08466223 0.68186728 0.95831023]\n",
      "   [0.09123023 0.08212613 0.21043752 0.49848096 0.08235595]\n",
      "   [0.74364385 0.14449196 0.3734369  0.08876523 0.63836527]]]]\n",
      "[[[[0.9236718  0.95831023]\n",
      "   [0.87327673 0.95831023]]]]\n"
     ]
    }
   ],
   "source": [
    "x = np.random.rand(1, 1, 5, 5)\n",
    "print(x)\n",
    "pool = Pooling(3, 3, 2, 0)\n",
    "x = pool.forward(x)\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-29T14:38:46.267613Z",
     "start_time": "2020-03-29T14:38:46.262359Z"
    }
   },
   "source": [
    "# CNN model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "import pickle\n",
    "\n",
    "# Relu, Affine and SoftmaxWithLoss are not defined in this notebook; they are\n",
    "# assumed to be available already (for example from an earlier notebook).\n",
    "\n",
    "\n",
    "class SimpleConvNet:\n",
    "    \"\"\"A simple ConvNet.\n",
    "    conv - relu - pool - affine - relu - affine - softmax\n",
    "    Parameters\n",
    "    ----------\n",
    "    input_dim : input shape (channels, height, width); (1, 28, 28) for MNIST\n",
    "    conv_param : dict with the keys 'filter_num', 'filter_size', 'pad', 'stride'\n",
    "    hidden_size : number of neurons in the hidden fully connected layer\n",
    "    output_size : output size (10 for MNIST)\n",
    "    weight_init_std : standard deviation of the initial weights (e.g. 0.01)\n",
    "    \"\"\"\n",
    "    def __init__(self,\n",
    "                 input_dim=(1, 28, 28),\n",
    "                 conv_param={\n",
    "                     'filter_num': 30,\n",
    "                     'filter_size': 5,\n",
    "                     'pad': 0,\n",
    "                     'stride': 1\n",
    "                 },\n",
    "                 hidden_size=100,\n",
    "                 output_size=10,\n",
    "                 weight_init_std=0.01):\n",
    "        filter_num = conv_param['filter_num']\n",
    "        filter_size = conv_param['filter_size']\n",
    "        pad = conv_param['pad']\n",
    "        stride = conv_param['stride']\n",
    "        input_size = input_dim[1]  # assumes a square input\n",
    "        conv_output_size = (input_size - filter_size + 2 * pad) // stride + 1\n",
    "        # 2x2 pooling with stride 2 halves the height and the width.\n",
    "        pool_output_size = filter_num * (conv_output_size // 2) * (\n",
    "            conv_output_size // 2)\n",
    "\n",
    "        self.params = {}\n",
    "        self.params['W1'] = weight_init_std * np.random.randn(\n",
    "            filter_num, input_dim[0], filter_size, filter_size)\n",
    "        self.params['b1'] = np.zeros(filter_num)\n",
    "        self.params['W2'] = weight_init_std * np.random.randn(\n",
    "            pool_output_size, hidden_size)\n",
    "        self.params['b2'] = np.zeros(hidden_size)\n",
    "        self.params['W3'] = weight_init_std * np.random.randn(\n",
    "            hidden_size, output_size)\n",
    "        self.params['b3'] = np.zeros(output_size)\n",
    "\n",
    "        self.layers = OrderedDict()\n",
    "        self.layers['Conv1'] = Convolution(self.params['W1'],\n",
    "                                           self.params['b1'], stride, pad)\n",
    "        self.layers['Relu1'] = Relu()\n",
    "        self.layers['Pool1'] = Pooling(pool_h=2, pool_w=2, stride=2)\n",
    "        self.layers['Affine1'] = Affine(self.params['W2'], self.params['b2'])\n",
    "        self.layers['Relu2'] = Relu()\n",
    "        self.layers['Affine2'] = Affine(self.params['W3'], self.params['b3'])\n",
    "        self.last_layer = SoftmaxWithLoss()\n",
    "\n",
    "    def predict(self, x):\n",
    "        for layer in self.layers.values():\n",
    "            x = layer.forward(x)\n",
    "        return x\n",
    "\n",
    "    def loss(self, x, t):\n",
    "        y = self.predict(x)\n",
    "        # The loss layer needs both the scores and the labels.\n",
    "        return self.last_layer.forward(y, t)\n",
    "\n",
    "    def accuracy(self, x, t, batch_size=100):\n",
    "        if t.ndim != 1:\n",
    "            t = np.argmax(t, axis=1)\n",
    "\n",
    "        acc = 0.0\n",
    "        for i in range(int(x.shape[0] / batch_size)):\n",
    "            tx = x[i * batch_size:(i + 1) * batch_size]\n",
    "            tt = t[i * batch_size:(i + 1) * batch_size]\n",
    "            y = self.predict(tx)\n",
    "            y = np.argmax(y, axis=1)\n",
    "            # Compare against the labels of this batch, not the whole set.\n",
    "            acc += np.sum(y == tt)\n",
    "        return acc / x.shape[0]\n",
    "\n",
    "    def gradient(self, x, t):\n",
    "        # Forward pass\n",
    "        self.loss(x, t)\n",
    "\n",
    "        # Backward pass through the loss layer, then all layers in reverse\n",
    "        dout = 1\n",
    "        dout = self.last_layer.backward(dout)\n",
    "        for layer in reversed(list(self.layers.values())):\n",
    "            dout = layer.backward(dout)\n",
    "\n",
    "        grads = {}\n",
    "        grads['W1'], grads['b1'] = self.layers['Conv1'].dW, self.layers[\n",
    "            'Conv1'].db\n",
    "        grads['W2'], grads['b2'] = self.layers['Affine1'].dW, self.layers[\n",
    "            'Affine1'].db\n",
    "        grads['W3'], grads['b3'] = self.layers['Affine2'].dW, self.layers[\n",
    "            'Affine2'].db\n",
    "        return grads\n",
    "\n",
    "    def save_params(self, file_name='params.pkl'):\n",
    "        params = {}\n",
    "        for key, val in self.params.items():\n",
    "            params[key] = val\n",
    "        with open(file_name, 'wb') as f:\n",
    "            pickle.dump(params, f)\n",
    "\n",
    "    def load_params(self, file_name='params.pkl'):\n",
    "        with open(file_name, 'rb') as f:\n",
    "            params = pickle.load(f)\n",
    "        # Copy the loaded values into self.params, then rebind the layers.\n",
    "        for key, val in params.items():\n",
    "            self.params[key] = val\n",
    "        # Parameter keys are 1-based: W1/b1, W2/b2, W3/b3.\n",
    "        for i, key in enumerate(['Conv1', 'Affine1', 'Affine2']):\n",
    "            self.layers[key].W = self.params['W' + str(i + 1)]\n",
    "            self.layers[key].b = self.params['b' + str(i + 1)]"
   ]
  },
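  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the layer sizes, the arithmetic below traces the default arguments of `SimpleConvNet`: a $(1,28,28)$ input, 30 filters of size $5\\times 5$ with pad 0 and stride 1, followed by $2\\times 2$ pooling with stride 2. It is a worked example of the output-size formula, not additional behavior.\n",
    "\n",
    "$$\n",
    "\\text{conv: }\\frac{28+2\\cdot 0-5}{1}+1=24,\\qquad \\text{pool: }\\frac{24-2}{2}+1=12,\\qquad 30\\cdot 12\\cdot 12=4320\n",
    "$$\n",
    "\n",
    "So the pooled feature maps have shape $(30,12,12)$, `W2` has shape $(4320,100)$, and `W3` has shape $(100,10)$."
   ]
  },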
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-29T16:13:03.451224Z",
     "start_time": "2020-03-29T16:13:03.444806Z"
    }
   },
   "outputs": [],
   "source": [
    "class Trainer:\n",
    "    \"\"\"Runs mini-batch training and records the loss and accuracy.\"\"\"\n",
    "    def __init__(self,\n",
    "                 network,\n",
    "                 x_train,\n",
    "                 t_train,\n",
    "                 x_test,\n",
    "                 t_test,\n",
    "                 epochs=20,\n",
    "                 mini_batch_size=100,\n",
    "                 optimizer='SGD',\n",
    "                 optimizer_params={'lr': 0.01},\n",
    "                 evaluate_sample_num_per_epoch=None,\n",
    "                 verbose=True):\n",
    "        self.network = network\n",
    "        self.verbose = verbose\n",
    "        self.x_train = x_train\n",
    "        self.t_train = t_train\n",
    "        self.x_test = x_test\n",
    "        self.t_test = t_test\n",
    "        self.epochs = epochs\n",
    "        self.batch_size = mini_batch_size\n",
    "        self.evaluate_sample_num_per_epoch = evaluate_sample_num_per_epoch\n",
    "\n",
    "        # The optimizer classes are not defined in this notebook; they are\n",
    "        # assumed to be available already (e.g. from an earlier notebook).\n",
    "        optimizer_class_dict = {\n",
    "            'sgd': SGD,\n",
    "            'momentum': Momentum,\n",
    "            'nesterov': Nesterov,\n",
    "            'adagrad': AdaGrad,\n",
    "            'rmsprop': RMSprop,\n",
    "            'adam': Adam\n",
    "        }\n",
    "        self.optimizer = optimizer_class_dict[optimizer.lower()](\n",
    "            **optimizer_params)\n",
    "        self.train_size = x_train.shape[0]\n",
    "        self.iter_per_epoch = max(self.train_size / mini_batch_size, 1)\n",
    "        self.max_iter = int(epochs * self.iter_per_epoch)\n",
    "        self.current_iter = 0\n",
    "        self.current_epoch = 0\n",
    "\n",
    "        self.train_loss_list = []\n",
    "        self.train_acc_list = []\n",
    "        self.test_acc_list = []\n",
    "\n",
    "    def train_step(self):\n",
    "        # Sample a random mini-batch from the stored training data.\n",
    "        batch_mask = np.random.choice(self.train_size, self.batch_size)\n",
    "        x_batch = self.x_train[batch_mask]\n",
    "        t_batch = self.t_train[batch_mask]\n",
    "\n",
    "        grads = self.network.gradient(x_batch, t_batch)\n",
    "        self.optimizer.update(self.network.params, grads)\n",
    "\n",
    "        loss = self.network.loss(x_batch, t_batch)\n",
    "        self.train_loss_list.append(loss)\n",
    "        if self.verbose:\n",
    "            print(\"train loss:\" + str(loss))\n",
    "        # Evaluate the accuracy once per epoch.\n",
    "        if self.current_iter % self.iter_per_epoch == 0:\n",
    "            self.current_epoch += 1\n",
    "\n",
    "            x_train_sample, t_train_sample = self.x_train, self.t_train\n",
    "            x_test_sample, t_test_sample = self.x_test, self.t_test\n",
    "            if self.evaluate_sample_num_per_epoch is not None:\n",
    "                # Evaluate on the first t samples only, to save time.\n",
    "                t = self.evaluate_sample_num_per_epoch\n",
    "                x_train_sample = self.x_train[:t]\n",
    "                t_train_sample = self.t_train[:t]\n",
    "                x_test_sample, t_test_sample = self.x_test[:t], self.t_test[:t]\n",
    "            train_acc = self.network.accuracy(x_train_sample, t_train_sample)\n",
    "            test_acc = self.network.accuracy(x_test_sample, t_test_sample)\n",
    "            self.train_acc_list.append(train_acc)\n",
    "            self.test_acc_list.append(test_acc)\n",
    "\n",
    "            if self.verbose:\n",
    "                print(\n",
    "                    f\"=== epoch: {self.current_epoch}, train acc: {train_acc}, test acc: {test_acc} ===\"\n",
    "                )\n",
    "        self.current_iter += 1\n",
    "\n",
    "    def train(self):\n",
    "        for i in range(self.max_iter):\n",
    "            self.train_step()\n",
    "\n",
    "        test_acc = self.network.accuracy(self.x_test, self.t_test)\n",
    "        if self.verbose:\n",
    "            print(\"====== Final Test Accuracy =====\")\n",
    "            print(\"test acc: \" + str(test_acc))"
   ]
  },
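  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the iteration bookkeeping in `Trainer` concrete, the arithmetic below assumes the standard MNIST training set of 60,000 images together with `mini_batch_size=100` and `epochs=20`. The first value is `iter_per_epoch` and the second is `max_iter`.\n",
    "\n",
    "$$\n",
    "\\frac{60000}{100}=600,\\qquad 20\\times 600=12000\n",
    "$$\n",
    "\n",
    "Under these assumptions the accuracy is evaluated whenever `current_iter` is a multiple of 600, so `train_acc_list` and `test_acc_list` each end up with 20 entries, one per epoch."
   ]
  },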
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# load_mnist is not defined in this notebook; it is assumed to be available\n",
    "# already (for example from a dataset helper used in earlier notebooks).\n",
    "(x_train, t_train), (x_test, t_test) = load_mnist(flatten=False)\n",
    "\n",
    "max_epochs = 20\n",
    "network = SimpleConvNet()\n",
    "trainer = Trainer(network,\n",
    "                  x_train,\n",
    "                  t_train,\n",
    "                  x_test,\n",
    "                  t_test,\n",
    "                  epochs=max_epochs,\n",
    "                  mini_batch_size=100,\n",
    "                  optimizer='Adam',\n",
    "                  optimizer_params={'lr': 0.001},\n",
    "                  evaluate_sample_num_per_epoch=1000)\n",
    "trainer.train()\n",
    "\n",
    "# Save the trained parameters so they can be reloaded later.\n",
    "network.save_params('params.pkl')\n",
    "\n",
    "# Plot the training and test accuracy per epoch.\n",
    "markers = {'train': 'o', 'test': 's'}\n",
    "x = np.arange(max_epochs)\n",
    "plt.plot(x, trainer.train_acc_list, marker=markers['train'], label='train', markevery=2)\n",
    "plt.plot(x, trainer.test_acc_list, marker=markers['test'], label='test', markevery=2)\n",
    "plt.xlabel(\"epochs\")\n",
    "plt.ylabel(\"accuracy\")\n",
    "plt.ylim(0, 1.0)\n",
    "plt.legend(loc='lower right')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualizing the filters (first-layer weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "def filter_show(filters, nx=8, margin=3, scale=10):\n",
    "    \"\"\"Show the first channel of every filter in a grid with nx columns.\"\"\"\n",
    "    FN, C, FH, FW = filters.shape\n",
    "    ny = int(np.ceil(FN / nx))\n",
    "    fig = plt.figure()\n",
    "    fig.subplots_adjust(left=0,\n",
    "                        right=1,\n",
    "                        bottom=0,\n",
    "                        top=1,\n",
    "                        hspace=0.05,\n",
    "                        wspace=0.05)\n",
    "    for i in range(FN):\n",
    "        ax = fig.add_subplot(ny, nx, i + 1, xticks=[], yticks=[])\n",
    "        ax.imshow(filters[i, 0], cmap=plt.cm.gray_r, interpolation='nearest')\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Filters right after random initialization\n",
    "network = SimpleConvNet()\n",
    "filter_show(network.params['W1'])\n",
    "\n",
    "# Filters after loading the trained parameters\n",
    "network.load_params(\"params.pkl\")\n",
    "filter_show(network.params['W1'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
